# -*- coding: utf-8 -*-

import tensorflow as tf
import numpy as np
import matplotlib as mpl
# Select the non-interactive Agg backend so figures can be rendered straight to
# image files without a display. This is done before matplotlib.pyplot is
# imported below.
mpl.use('Agg')
mpl.rcParams['font.sans-serif'] = ['SimHei']  # use the SimHei font so Chinese characters can be displayed
import matplotlib.pyplot as plt
import os
import shutil


# Sampling configuration:
#   batch_size -- number of noise vectors (and thus images) per sess.run call.
#   z_dim      -- length of each noise vector; presumably must match the width
#                 of the restored graph's 'noise:0' placeholder (TODO confirm).
batch_size, z_dim = 100, 512


def save_imgs(images, batch_num, out_dir='./gen'):
    """Save each image in ``images`` as a separate grayscale PNG file.

    Files are named ``{batch_num}_{i}.png`` inside ``out_dir``, where ``i`` is
    the image's position in ``images``. Existing files with the same name are
    overwritten; other files already in ``out_dir`` are left untouched.

    Args:
        images: A list or array of images. Lists are converted to a NumPy
            array first. Each element is drawn with
            ``plt.imshow(..., cmap='gray')``, so presumably each is a 2-D
            (H, W) array -- TODO confirm for other shapes.
        batch_num: Prefix for every output filename, so images from different
            batches do not collide.
        out_dir: Directory to write into; created if it does not exist.
            Defaults to ``'./gen'``.
    """
    if isinstance(images, list):
        images = np.array(images)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    for i, image in enumerate(images):
        # Use a fresh figure per image and always close it. Drawing every
        # image onto one shared pyplot figure keeps all previous images on the
        # axes, so memory use and savefig time grow with every image saved.
        fig = plt.figure()
        try:
            plt.axis('off')
            # NOTE: with no vmin/vmax, imshow scales each image to its own
            # min/max, so the absolute intensity range of the input is not
            # preserved in the saved file.
            plt.imshow(image, cmap='gray')
            fig.savefig(os.path.join(out_dir, '{}_{}.png'.format(batch_num, i)))
        finally:
            plt.close(fig)


# Restore the trained DCGAN from the current directory, then sample 50 batches
# of random noise through the generator and write the results to PNG files.
# No variable initializer is run: saver.restore() assigns every saved variable.
with tf.Session() as sess:
    # Rebuild the graph structure from the step-1000 meta file, then load the
    # variable values from the most recent checkpoint in './'.
    # NOTE(review): if later steps were saved, the weights come from the latest
    # step while the graph definition comes from step 1000 -- confirm the two
    # graphs are identical.
    saver = tf.train.import_meta_graph('./dcgan-1000.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./'))

    # Look up the generator output and its input placeholders by name.
    graph = tf.get_default_graph()
    g = graph.get_tensor_by_name('generator/g/Tanh:0')
    noise = graph.get_tensor_by_name('noise:0')
    is_training = graph.get_tensor_by_name('is_training:0')

    for batch_num in range(50):
        print(batch_num)
        n = np.random.uniform(-1.0, 1.0, [batch_size, z_dim]).astype(np.float32)
        gen_imgs = sess.run(g, feed_dict={noise: n, is_training: False})
        # The output comes from a tanh op (range [-1, 1]); rescale to [0, 1].
        gen_imgs = (gen_imgs + 1) / 2
        # Keep only channel 0 of every image in one vectorized slice.
        # Presumably the layout is (batch, H, W, channels) -- TODO confirm.
        imgs = gen_imgs[..., 0]
        save_imgs(imgs, batch_num)

